import os
import pickle
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as datasets

from data.dl_getter import get_transform, normalize

home_dir = os.path.expanduser("~")

# Map each out-of-distribution benchmark name to its location under the
# user's home directory. Entries ending in ".pkl" are pickle files (these are
# the ones `Non_dataset` loads); the rest are directory paths.
ood_root = {
    name: os.path.join(home_dir, relative_path)
    for name, relative_path in (
        ('dtd', 'data/dtd/images'),
        ('iSUN', 'data/iSUN'),
        ('LSUN', 'data/LSUN'),
        ('LSUN_R', 'data/LSUN_resize'),
        ('places365', 'data/places365'),
        ('N', 'data/non_natural/N.pkl'),
        ('U', 'data/non_natural/U.pkl'),
        ('OODomain', 'data/non_natural/OODomain.pkl'),
        ('Constant', 'data/non_natural/Constant.pkl'),
    )
}


class Robust_ds(Dataset):
    """Leading subset of the CIFAR-10 test split, passed through ``get_transform()``.

    Args:
        root_dir: Passed to ``torchvision.datasets.CIFAR10`` as ``root``. The
            dataset must already exist there because ``download=False``.
        n_samples: Number of leading test-set samples to keep. Defaults to
            1000, the value this class previously hard-coded.

    Items are ``(image, label)`` pairs. ``image`` is the raw entry from the
    CIFAR-10 ``.data`` array (run through ``self.transform`` when it is truthy).
    It is not converted to a PIL image first. NOTE(review): presumably a uint8
    HWC numpy array, so ``get_transform()`` must accept that; confirm.
    """

    def __init__(self, root_dir='~/data', n_samples=1000):
        self.root_dir = root_dir
        self.transform = get_transform()
        # Construct the test split once and slice both fields from the same
        # instance, rather than loading CIFAR-10 from disk twice.
        cifar_test = datasets.CIFAR10(root=root_dir, train=False, download=False)
        self.data = cifar_test.data[:n_samples]
        self.targets = cifar_test.targets[:n_samples]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image, label = self.data[idx], self.targets[idx]
        if self.transform:
            image = self.transform(image)
        return image, label
    

class Non_dataset(Dataset):
    """Samples loaded from one of the pickle files listed in ``ood_root``.

    Every item is returned as ``(sample, -1)``. NOTE(review): -1 presumably
    marks "no in-distribution class" for OOD data; confirm with callers.
    Samples are returned exactly as unpickled; no transform is applied here.

    Args:
        type: Key into ``ood_root`` (e.g. ``'N'``, ``'U'``, ``'OODomain'``,
            ``'Constant'``). The name shadows the builtin but is kept so that
            callers passing ``type=...`` keep working.

    Raises:
        KeyError: If ``type`` is not a key of ``ood_root``.
    """

    def __init__(self, type):
        try:
            path = ood_root[type]
        except KeyError:
            raise KeyError(
                f"unknown OOD dataset {type!r}; expected one of {sorted(ood_root)}"
            ) from None
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only point ood_root at trusted, locally generated pickles.
        # The `with` block guarantees the file handle is closed.
        with open(path, 'rb') as f:
            self.data = pickle.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image = self.data[idx]
        return image, -1